from MSTH.configs.method_import import *
import numpy as np
import itertools
import random

# Registry of named training configurations, keyed by a short variant name;
# every experiment below registers itself by assigning into this dict.
method_configs: Dict[str, Union[TrainerConfig, VideoTrainerConfig]] = dict()

# "base": reference setup for the flame_salmon scene - 10k iterations, "mt"
# mode with a single global spatial mask, and one small proposal network.
method_configs["base"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    # Larger than max_num_iterations, so presumably only the final checkpoint is written.
    steps_per_save=20000,
    max_num_iterations=10000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Schedule over step x: 1 for the first 1000 steps, then a quarter-sine
            # ramp reaching 1 + 5 = 6 at step 10000 (= max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            # Static:dynamic sampling ratio moves from 50 towards 10 over 10000 steps.
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # Presumably one entry per proposal iteration (num_proposal_iterations=1).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share one Adam + exponential-decay schedule; the decay
    # horizon (15000) is longer than the run, so lr_final is never fully reached.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
    ),
)

# "base_nomask": ablation of "base" that sets ``ablation_add=True`` in both the
# main field and the proposal network (judging by the variant name, this
# presumably removes the learned static/dynamic mask - confirm in the model).
# The mask loss is switched off (mask_loss_mult=0.0) to match the parallel
# 40k-iteration ablation "base_it40000_st_4_nomask"; everything else equals "base".
method_configs["base_nomask"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    # Larger than max_num_iterations, so presumably only the final checkpoint is written.
    steps_per_save=20000,
    max_num_iterations=10000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Schedule over step x: 1 for the first 1000 steps, then a quarter-sine
            # ramp reaching 1 + 5 = 6 at step 10000 (= max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            # Disabled for the no-mask ablation, consistent with
            # "base_it40000_st_4_nomask" (was 1.0).
            mask_loss_mult=0.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            ablation_add=True,
            proposal_net_args_list=[
                {
                    "hidden_dim": 16,
                    "log2_hashmap_size_spatial": 17,
                    "log2_hashmap_size_temporal": 17,
                    "num_levels": 5,
                    "max_res": (128, 128, 128, 150),
                    "base_res": (16, 16, 16, 15),
                    "use_linear": False,
                    "mode": "mt",
                    "mask_reso": (64, 64, 64),
                    "mask_log2_hash_size": 18,
                    "mask_type": "global",
                    "st_mlp_mode": "shared",
                    "interp": "linear",
                    "mask_init_mean": -1,
                    "ablation_add": True,
                },
            ],
        ),
    ),
    optimizers={
        "proposal_networks": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        },
        "fields": {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        },
    },
)


# "base_mst": same as "base" except both the main field and the proposal
# network run in "mst" mode instead of "mt".
method_configs["base_mst"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=20000,
    max_num_iterations=10000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 10000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mst",  # differs from "base" ("mt")
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mst",  # kept in sync with the main field's mst_mode
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
    ),
)

# "base_20": same as "base" but with larger hash tables - 2**20 entries for the
# main field and 2**18 for the proposal network (vs 2**19 / 2**17).
method_configs["base_20"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=20000,
    max_num_iterations=10000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 10000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=20,
            log2_hashmap_size_temporal=20,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
    ),
)

# "base_32768": same as "base" but doubles the training batch to 32768 rays.
method_configs["base_32768"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=20000,
    max_num_iterations=10000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=32768,  # "base" uses 16384
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 10000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
    ),
)

# "base_20_mask_with_t": "base_20" hash sizes plus a mask that also has a time
# axis - mask_type "global_multitime" with a 4-D resolution (64, 64, 64, 16) and
# a 2**22 mask hash table, in both the main field and the proposal network.
method_configs["base_20_mask_with_t"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=20000,
    max_num_iterations=10000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 10000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=20,
            log2_hashmap_size_temporal=20,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(64, 64, 64, 16),
            mask_log2_hash_size=22,
            mask_type="global_multitime",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64, 16),
                    mask_log2_hash_size=22,
                    mask_type="global_multitime",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15000),
        ),
    ),
)

# "base_it40000": a 40k-iteration run. The n_time_for_dynamic ramp is stretched
# to 40000 steps with amplitude 2, the proposal hash tables use 2**18 entries,
# and both learning rates decay to 2e-4 over 35000 steps.
method_configs["base_it40000"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    # Larger than max_num_iterations, so presumably only the final checkpoint is written.
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 1 + 2 = 3 at step 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 2 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # Horizon is far beyond max_num_iterations=40000, so the ratio presumably
            # stays much closer to 50 than to 10 for the whole run.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
    ),
)

# "base_it40000_st_4": same as "base_it40000" except the n_time_for_dynamic
# ramp uses amplitude 5 (peaking at 6 instead of 3).
method_configs["base_it40000_st_4"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # Deliberately longer than the run; "base_it40000_st_5" uses 40000 instead.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
    ),
)


# "base_it40000_st_5": same as "base_it40000_st_4" except the static/dynamic
# sampling ratio decays over exactly the training length (40000 steps).
method_configs["base_it40000_st_5"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=40000,  # matches max_num_iterations
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
    ),
)


# "base_it40000_st_4_nomask": ablation of "base_it40000_st_4". It sets
# ablation_add=True in both the main field and the proposal network and turns
# the mask loss off (mask_loss_mult=0.0).
method_configs["base_it40000_st_4_nomask"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until step 1000, then a quarter-sine ramp up to 6 at step 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=0.0,  # mask loss disabled for this ablation
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            ablation_add=True,
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                    ablation_add=True,
                ),
            ],
        ),
    ),
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=35000),
        ),
    ),
)

# "base_lr2": 10k-iteration run whose optimisation schedule uses Adam at
# lr=1.5e-2 for both parameter groups, decayed exponentially toward 5e-4.
# NOTE(review): the scheduler's max_steps (15_000) is larger than
# max_num_iterations (10_000), so presumably lr_final is never reached
# before training stops — confirm this is intended.
method_configs["base_lr2"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=20_000,
    max_num_iterations=10_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 10000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 17) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1.5e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1.5e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15_000),
        ),
    ),
)

# "base_lr3": 10k-iteration run whose optimisation schedule uses Adam at
# lr=7.5e-3 for both parameter groups, decayed exponentially toward 5e-4.
# NOTE(review): the scheduler's max_steps (15_000) is larger than
# max_num_iterations (10_000), so presumably lr_final is never reached
# before training stops — confirm this is intended.
method_configs["base_lr3"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=20_000,
    max_num_iterations=10_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 10000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (10000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=10_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=1.0,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 17) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=7.5e-3, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=7.5e-3, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=15_000),
        ),
    ),
)


# base_method = method_configs["base_it40000_st_4"]


# def setp(exps):
#     def setfunc(x, v):
#         if isinstance(exps, (tuple, list)):
#             for exp in exps:
#                 command = "x." + exp + "=v"
#                 exec(command)
#         else:
#             command = "x." + exps + "=v"
#             exec(command)

#     return setfunc


# set_functions = {
#     "n_time": setp("pipeline.datamanager.n_time_for_dynamic"),
#     "sampling_ratio_start": setp("pipeline.datamanager.static_dynamic_sampling_ratio"),
#     "sampling_ratio_end": setp("pipeline.datamanager.static_dynamic_sampling_ratio_end"),
#     "sampling_ratio_decay": setp("pipeline.datamanager.static_ratio_decay_total_steps"),
#     "mask_loss_mult": setp("pipeline.model.mask_loss_mult"),
#     "mask_init_mean": setp(
#         ["pipeline.model.mask_init_mean", "pipeline.model.proposal_net_args_list[0]['mask_init_mean']"]
#     ),
# }

# potential_values = {
#     # "n_time": [
#     # lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
#     # ],
#     # "sampling_ratio_start": [50, 40, 30, 20, 10],
#     # "sampling_ratio_end": [40, 30, 20, 10, 1],
#     "mask_init_mean": [
#         -1,
#         -0.5,
#         0.0,
#         0.5,
#         1,
#     ],
#     "mask_loss_mult": [
#         0.01,
#         0.1,
#         0.5,
#         1,
#     ],
# }

# all_hyper_parameter_key = potential_values.keys()
# all_hyper_parameter_value = [potential_values[k] for k in all_hyper_parameter_key]
# all_specs = list(itertools.product(*all_hyper_parameter_value))
# all_specs = [{k: v for k, v in zip(all_hyper_parameter_key, spec)} for spec in all_specs]
# random.shuffle(all_specs)
# print("==== ALL SPECS ====")
# print(all_specs)
# all_hyper_parameter_key = potential_values.keys()
# all_hyper_parameter_value = [potential_values[k] for k in all_hyper_parameter_key]
# all_specs = list(itertools.product(*all_hyper_parameter_value))
# all_specs = [{k: v for k, v in zip(all_hyper_parameter_key, spec)} for spec in all_specs]
# random.shuffle(all_specs)
# print("==== ALL SPECS ====")
# print(all_specs)

# for i, spec in enumerate(all_specs):
#     method_configs[f"anoynmous_method_{i}"] = copy.deepcopy(base_method)
#     for k, v in spec.items():
#         set_functions[k](method_configs[f"anoynmous_method_{i}"], v)
#     print(method_configs[f"anoynmous_method_{i}"])
# # set_sampling_ratio_decay(anoynmous_method, anoynmous_method.max_num_iterations)
# for i, spec in enumerate(all_specs):
#     method_configs[f"anoynmous_method_{i}"] = copy.deepcopy(base_method)
#     for k, v in spec.items():
#         set_functions[k](method_configs[f"anoynmous_method_{i}"], v)
#     print(method_configs[f"anoynmous_method_{i}"])
# # set_sampling_ratio_decay(anoynmous_method, anoynmous_method.max_num_iterations)


# "base_it80000_base": 80k-iteration run with mask_loss_mult=0.1 and Adam at
# lr=1e-2 decayed exponentially toward 2e-4 over 75k steps.
# NOTE(review): static_ratio_decay_total_steps (500_000) is far beyond
# max_num_iterations (80_000); if the static/dynamic ratio decays over that
# horizon it will stay close to its starting value — confirm this is not a typo.
method_configs["base_it80000_base"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=80_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 80000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (80000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 18) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=75_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=2e-4, max_steps=75_000),
        ),
    ),
)

# "base_it40000_base": 40k-iteration reference run with mask_loss_mult=0.1 and
# Adam at lr=1e-2 decayed exponentially toward 1e-4 over 40k steps.
# NOTE(review): static_ratio_decay_total_steps (500_000) is far beyond
# max_num_iterations (40_000); if the static/dynamic ratio decays over that
# horizon it will stay close to its starting value — confirm this is not a typo.
method_configs["base_it40000_base"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 18) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
    ),
)

# "base_it40000_base_lrlonger": identical to the 40k reference setup except the
# exponential LR decay horizon is 60k steps while training stops at 40k, so the
# learning rate presumably ends well above lr_final=1e-4 (hence "lrlonger").
# NOTE(review): static_ratio_decay_total_steps (500_000) is far beyond
# max_num_iterations (40_000) — confirm this is not a typo.
method_configs["base_it40000_base_lrlonger"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 18) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Decay horizon (60k) intentionally exceeds the 40k training iterations.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=60_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=60_000),
        ),
    ),
)

# "base_it40000_base_mst_nt50": 40k-iteration run that switches both the main
# field and the proposal network to mode "mst", and uses a much steeper
# n_time_for_dynamic ramp (amplitude 50 instead of 5).
# NOTE(review): static_ratio_decay_total_steps (500_000) is far beyond
# max_num_iterations (40_000) — confirm this is not a typo.
method_configs["base_it40000_base_mst_nt50"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 51 at iteration 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 50 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mst",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network in "mst" mode, with log2 hash size 18 (main field uses 19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mst",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
    ),
)

# "base_it40000_base_64": the 40k reference setup with 64 NeRF samples per ray
# (instead of 48).
# NOTE(review): static_ratio_decay_total_steps (500_000) is far beyond
# max_num_iterations (40_000) — confirm this is not a typo.
method_configs["base_it40000_base_64"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=64,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 18) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
    ),
)

# "base_it40000_base_32": the 40k setup with 32 NeRF samples per ray
# (instead of 48).
# NOTE(review): lr_final here is 5e-4 whereas the sibling 40k variants use
# 1e-4; the variant name only mentions the sample count — confirm the LR
# change is intentional. static_ratio_decay_total_steps (500_000) is also far
# beyond max_num_iterations (40_000).
method_configs["base_it40000_base_32"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=32,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, with smaller hash tables (log2 size 18) than the main field (19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=40_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=40_000),
        ),
    ),
)

# "base_it40000_base_independent": the 40k reference setup with
# st_mlp_mode="independent" (instead of "shared") for both the main field and
# the proposal network.
# NOTE(review): static_ratio_decay_total_steps (500_000) is far beyond
# max_num_iterations (40_000) — confirm this is not a typo.
method_configs["base_it40000_base_independent"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # 1 until iteration 1000, then a quarter-sine ramp that reaches 6 at iteration 40000.
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="independent",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One proposal network, also "independent", with log2 hash size 18 (main field uses 19).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="independent",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups share the same (separately instantiated) optimiser/scheduler settings.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        ),
    ),
)

# Single-knob variants of "base_it40000_base" (defined earlier in this file). Each variant is
# a deep copy of its parent, so overriding a field never mutates the parent. The underscore
# names are scratch variables; they are deleted at the end so they do not leak into the
# module namespace.

# mask_type switched to "global_timequery" for both the field and its proposal network.
_variant = copy.deepcopy(method_configs["base_it40000_base"])
_variant.pipeline.model.mask_type = "global_timequery"
_variant.pipeline.model.proposal_net_args_list[0]["mask_type"] = "global_timequery"
method_configs["base_it40000_base_withtimequery"] = _variant

# distortion_loss_mult sweep.
for _name, _mult in (
    ("base_it40000_base_dist1", 0.001),
    ("base_it40000_base_dist2", 0.003),
    ("base_it40000_base_dist3", 0.006),
):
    _variant = copy.deepcopy(method_configs["base_it40000_base"])
    _variant.pipeline.model.distortion_loss_mult = _mult
    method_configs[_name] = _variant

# distortion_loss_mult 0.003 with an explicit _end (0.002) and _decay (20000).
_variant = copy.deepcopy(method_configs["base_it40000_base"])
_variant.pipeline.model.distortion_loss_mult = 0.003
_variant.pipeline.model.distortion_loss_mult_end = 0.002
_variant.pipeline.model.distortion_loss_mult_decay = 20000
method_configs["base_it40000_base_dist4"] = _variant

# dist4 plus sparse_loss_mult 1e-5 with _end=0 and _decay=15000.
_variant = copy.deepcopy(method_configs["base_it40000_base_dist4"])
_variant.pipeline.model.sparse_loss_mult = 1e-5
_variant.pipeline.model.sparse_loss_mult_end = 0
_variant.pipeline.model.sparse_loss_mult_decay = 15000
method_configs["base_it40000_base_dist4_sparse1"] = _variant

# dist4_sparse1 with dist_sharpen=2.
_variant = copy.deepcopy(method_configs["base_it40000_base_dist4_sparse1"])
_variant.pipeline.model.dist_sharpen = 2
method_configs["base_it40000_base_dist4_sparse1_sharpen2"] = _variant

# Only sparse_loss_mult changed; distortion settings inherited from the parent.
_variant = copy.deepcopy(method_configs["base_it40000_base"])
_variant.pipeline.model.sparse_loss_mult = 1e-5
method_configs["base_it40000_base_sparse1"] = _variant

# Dataparser scale_factor sweep.
for _name, _scale in (
    ("base_it40000_base_scale0.125", 0.125),
    ("base_it40000_base_scale0.25", 0.25),
):
    _variant = copy.deepcopy(method_configs["base_it40000_base"])
    _variant.pipeline.datamanager.dataparser.scale_factor = _scale
    method_configs[_name] = _variant

# scale_factor paired with model.middle_distance and a per-variant steps_per_eval_image.
for _name, _scale, _middle, _eval_every in (
    ("base_it40000_base_scale1_md_4", 1.0, 4, 5000),
    ("base_it40000_base_scale0.5_md_1.5", 0.5, 1.5, 2000),
):
    _variant = copy.deepcopy(method_configs["base_it40000_base"])
    _variant.pipeline.datamanager.dataparser.scale_factor = _scale
    _variant.pipeline.model.middle_distance = _middle
    _variant.steps_per_eval_image = _eval_every
    method_configs[_name] = _variant

# scale0.5_md_1.5 with n_time_for_dynamic fixed at 1 for every step.
_variant = copy.deepcopy(method_configs["base_it40000_base_scale0.5_md_1.5"])
_variant.pipeline.datamanager.n_time_for_dynamic = lambda x: 1
_variant.pipeline.datamanager.static_dynamic_sampling_ratio_end = 10
method_configs["base_it40000_base_scale0.5_md_1.5_constant"] = _variant

# The "_constant" config with the scheduler of both optimizer groups set to a
# CosineDecaySchedulerConfig (a separate instance per group).
_variant = copy.deepcopy(method_configs["base_it40000_base_scale0.5_md_1.5_constant"])
for _group in ("proposal_networks", "fields"):
    _variant.optimizers[_group]["scheduler"] = CosineDecaySchedulerConfig(lr_final=1e-4, max_steps=40000)
method_configs["base_it40000_base_scale0.5_md_1.5_constant_cosine"] = _variant

del _variant, _name, _mult, _scale, _middle, _eval_every, _group

# 40k-iteration experiment on flame_salmon_videos_2 with shared st_mlp_mode and
# dist_sharpen=2.0 on the model.
method_configs["base_it40000_base_distsharpen2"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Step schedule: 1 before step 1000, then 1 + 5 * sin(...) rising along a quarter
            # sine wave from 1 (step 1000) to 6 (step 40000 == max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # NOTE(review): 500000 is far beyond max_num_iterations (40000); if the ratio decays
            # over this many steps it will not reach static_dynamic_sampling_ratio_end during
            # training. Confirm this is intended.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            dist_sharpen=2.0,
            # One entry, consistent with num_proposal_iterations=1.
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups: Adam at lr=1e-2, exponential decay to 1e-4 over 40000 steps.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
    ),
)

# Deep copies of "base_it40000_base_distsharpen2" that differ only in model.dist_sharpen.
# The loop names are scratch variables and are removed afterwards.
for _name, _sharpen in (
    ("base_it40000_base_distsharpen5", 5.0),
    ("base_it40000_base_distsharpen10", 10.0),
):
    method_configs[_name] = copy.deepcopy(method_configs["base_it40000_base_distsharpen2"])
    method_configs[_name].pipeline.model.dist_sharpen = _sharpen
del _name, _sharpen

# 40k-iteration experiment on flame_salmon_videos_2 with use_llff_poses=True on the
# dataparser and shared st_mlp_mode.
method_configs["base_it40000_base_llffpose"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
                use_llff_poses=True,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Step schedule: 1 before step 1000, then 1 + 5 * sin(...) rising along a quarter
            # sine wave from 1 (step 1000) to 6 (step 40000 == max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # NOTE(review): 500000 is far beyond max_num_iterations (40000); if the ratio decays
            # over this many steps it will not reach static_dynamic_sampling_ratio_end during
            # training. Confirm this is intended.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One entry, consistent with num_proposal_iterations=1.
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups: Adam at lr=1e-2, exponential decay to 1e-4 over 40000 steps.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
    ),
)

# 40k-iteration experiment reading frames from flame_salmon_videos_720p_area, with shared
# st_mlp_mode.
method_configs["base_it40000_base_720p_area"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_720p_area/"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Step schedule: 1 before step 1000, then 1 + 5 * sin(...) rising along a quarter
            # sine wave from 1 (step 1000) to 6 (step 40000 == max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # NOTE(review): 500000 is far beyond max_num_iterations (40000); if the ratio decays
            # over this many steps it will not reach static_dynamic_sampling_ratio_end during
            # training. Confirm this is intended.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One entry, consistent with num_proposal_iterations=1.
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups: Adam at lr=1e-2, exponential decay to 1e-4 over 40000 steps.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
    ),
)

# Identical to "base_it40000_base_720p_area" (defined earlier in this file) except that the
# dataparser reads from flame_salmon_videos_720p_linear. It was previously an 80-line copy of
# that config. It is now derived with the deepcopy-and-override pattern used elsewhere in this
# file, so a hyperparameter change to the area config cannot silently diverge from this one.
method_configs["base_it40000_base_720p_linear"] = copy.deepcopy(method_configs["base_it40000_base_720p_area"])
method_configs["base_it40000_base_720p_linear"].pipeline.datamanager.dataparser.data = Path(
    "/data/machine/data/flame_salmon_videos_720p_linear/"
)


# 40k-iteration experiment on flame_salmon_videos_2 with mask_loss_for_proposal=False and
# shared st_mlp_mode.
method_configs["base_it40000_base_maskproposal_off"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Step schedule: 1 before step 1000, then 1 + 5 * sin(...) rising along a quarter
            # sine wave from 1 (step 1000) to 6 (step 40000 == max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # NOTE(review): 500000 is far beyond max_num_iterations (40000); if the ratio decays
            # over this many steps it will not reach static_dynamic_sampling_ratio_end during
            # training. Confirm this is intended.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=0.1,
            mask_loss_for_proposal=False,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One entry, consistent with num_proposal_iterations=1.
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups: Adam at lr=1e-2, exponential decay to 1e-4 over 40000 steps.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
    ),
)


# "base_it40000_base_maskproposal_off" (defined earlier in this file) with fewer samples per
# ray: 48 proposal samples and 6 NeRF samples instead of 128 / 48. It was previously a full
# copy of that config. It is now derived by deepcopy so the two cannot drift apart in any
# other setting.
method_configs["base_it40000_base_maskproposal_off_tiny"] = copy.deepcopy(
    method_configs["base_it40000_base_maskproposal_off"]
)
method_configs["base_it40000_base_maskproposal_off_tiny"].pipeline.model.num_proposal_samples_per_ray = (48,)
method_configs["base_it40000_base_maskproposal_off_tiny"].pipeline.model.num_nerf_samples_per_ray = 6


# "base_it40000_base_maskproposal_off" (defined earlier in this file) with more samples per
# ray: 256 proposal samples and 96 NeRF samples instead of 128 / 48. It was previously a full
# copy of that config. It is now derived by deepcopy so the two cannot drift apart in any
# other setting.
method_configs["base_it40000_base_maskproposal_off_large"] = copy.deepcopy(
    method_configs["base_it40000_base_maskproposal_off"]
)
method_configs["base_it40000_base_maskproposal_off_large"].pipeline.model.num_proposal_samples_per_ray = (256,)
method_configs["base_it40000_base_maskproposal_off_large"].pipeline.model.num_nerf_samples_per_ray = 96

# 40k-iteration experiment on flame_salmon_videos_2 with proposal_initial_sampler="uniform"
# and far_plane=40 on the model, and shared st_mlp_mode.
method_configs["base_it40000_base_uniform_far40"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1000,
    steps_per_save=50000,
    max_num_iterations=40000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Step schedule: 1 before step 1000, then 1 + 5 * sin(...) rising along a quarter
            # sine wave from 1 (step 1000) to 6 (step 40000 == max_num_iterations).
            n_time_for_dynamic=lambda x: 1 if x < 1000 else 1 + 5 * np.sin((x - 1000) * np.pi / (2 * (40000 - 1000))),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            # NOTE(review): 500000 is far beyond max_num_iterations (40000); if the ratio decays
            # over this many steps it will not reach static_dynamic_sampling_ratio_end during
            # training. Confirm this is intended.
            static_ratio_decay_total_steps=500000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            proposal_initial_sampler="uniform",
            far_plane=40,
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # One entry, consistent with num_proposal_iterations=1.
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups: Adam at lr=1e-2, exponential decay to 1e-4 over 40000 steps.
    optimizers=dict(
        proposal_networks=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
        fields=dict(
            optimizer=AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            scheduler=ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40000),
        ),
    ),
)


# Variant: 40k-iteration "base" model on flame_salmon_videos_2 with camera pose
# refinement enabled (SO3xR3 camera optimizer).
# Alternative dataset previously tried here: /data/machine/data/fit/videos_2
# NOTE(review): static_ratio_decay_total_steps (500_000) is far longer than
# max_num_iterations (40_000); confirm the slow static/dynamic ratio decay is intended.
method_configs["base_it40000_base_cameraopt"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1_000,
    steps_per_save=50_000,
    max_num_iterations=40_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(
                mode="SO3xR3",
                optimizer=AdamOptimizerConfig(lr=6e-4, eps=1e-8, weight_decay=1e-2),
            ),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Schedule of x (presumably the training step): constant 1 before
            # x = 1_000, then 1 + 5*sin(...) rising to 6 at x = 40_000, where the
            # sine argument reaches pi/2.
            n_time_for_dynamic=lambda x: (
                1 if x < 1_000 else 1 + 5 * np.sin((x - 1_000) * np.pi / (2 * (40_000 - 1_000)))
            ),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5_000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            # A single proposal network (matches num_proposal_iterations=1 above).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=18,
                    log2_hashmap_size_temporal=18,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups use identical settings; the comprehension still
    # constructs a fresh optimizer/scheduler config object for each group.
    optimizers={
        group: {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        }
        for group in ("proposal_networks", "fields")
    },
)


# Variant: 20k-iteration "base" model on flame_salmon_videos_2 with
# dist_sharpen=2.0 and proposal-network log2_hashmap_size_* = 17; camera
# optimization is off.
# Alternative dataset previously tried here: /data/machine/data/fit/videos_2
# NOTE(review): the n_time_for_dynamic ramp and both LR schedulers use a 40_000-step
# horizon, but max_num_iterations is 20_000. Training therefore stops mid-schedule:
# the ramp reaches only ~4.5 instead of 6, and presumably the learning rates never
# decay to lr_final. Confirm this is intended before reusing this config.
# NOTE(review): static_ratio_decay_total_steps (500_000) also far exceeds max_num_iterations.
method_configs["base_it20000_base_distsharpen2"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1_000,
    steps_per_save=50_000,
    max_num_iterations=20_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Schedule of x (presumably the training step): constant 1 before
            # x = 1_000, then 1 + 5*sin(...) which would reach 6 only at x = 40_000.
            n_time_for_dynamic=lambda x: (
                1 if x < 1_000 else 1 + 5 * np.sin((x - 1_000) * np.pi / (2 * (40_000 - 1_000)))
            ),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5_000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            dist_sharpen=2.0,
            # A single proposal network (matches num_proposal_iterations=1 above).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups use identical settings; the comprehension still
    # constructs a fresh optimizer/scheduler config object for each group.
    optimizers={
        group: {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        }
        for group in ("proposal_networks", "fields")
    },
)


# Variant: 20k-iteration "base" model on flame_salmon_videos_2 with
# dist_sharpen=0.7 and proposal-network log2_hashmap_size_* = 17; camera
# optimization is off.
# Alternative dataset previously tried here: /data/machine/data/fit/videos_2
# NOTE(review): the n_time_for_dynamic ramp and both LR schedulers use a 40_000-step
# horizon, but max_num_iterations is 20_000. Training therefore stops mid-schedule:
# the ramp reaches only ~4.5 instead of 6, and presumably the learning rates never
# decay to lr_final. Confirm this is intended before reusing this config.
# NOTE(review): static_ratio_decay_total_steps (500_000) also far exceeds max_num_iterations.
method_configs["base_it20000_base_distsharpen0.7"] = SpaceTimeHashingTrainerConfig(
    method_name="Spatial_Time_Hashing_With_Base",
    steps_per_eval_batch=1_000,
    steps_per_save=50_000,
    max_num_iterations=20_000,
    mixed_precision=True,
    log_gradients=True,
    pipeline=SpaceTimePipelineConfig(
        datamanager=SpaceTimeDataManagerConfig(
            dataparser=VideoDataParserConfig(
                data=Path("/data/machine/data/flame_salmon_videos_2"),
                downscale_factor=2,
                scale_factor=0.5,
            ),
            train_num_rays_per_batch=16_384,
            camera_optimizer=CameraOptimizerConfig(mode="off"),
            use_uint8=True,
            spatial_temporal_sampler="st",
            # Schedule of x (presumably the training step): constant 1 before
            # x = 1_000, then 1 + 5*sin(...) which would reach 6 only at x = 40_000.
            n_time_for_dynamic=lambda x: (
                1 if x < 1_000 else 1 + 5 * np.sin((x - 1_000) * np.pi / (2 * (40_000 - 1_000)))
            ),
            static_dynamic_sampling_ratio=50,
            static_dynamic_sampling_ratio_end=10,
            static_ratio_decay_total_steps=500_000,
        ),
        model=DSpaceTimeHashingModelConfig(
            max_res=(2048, 2048, 2048, 300),
            base_res=(16, 16, 16, 15),
            num_proposal_samples_per_ray=(128,),
            num_nerf_samples_per_ray=48,
            proposal_weights_anneal_max_num_iters=5_000,
            log2_hashmap_size_spatial=19,
            log2_hashmap_size_temporal=19,
            eval_num_rays_per_chunk=32_768,
            mask_loss_mult=0.1,
            mst_mode="mt",
            mask_reso=(128, 128, 128),
            mask_log2_hash_size=21,
            mask_type="global",
            st_mlp_mode="shared",
            num_proposal_iterations=1,
            use_loss_static=False,
            render_static=False,
            interp="linear",
            mask_init_mean=-1,
            dist_sharpen=0.7,
            # A single proposal network (matches num_proposal_iterations=1 above).
            proposal_net_args_list=[
                dict(
                    hidden_dim=16,
                    log2_hashmap_size_spatial=17,
                    log2_hashmap_size_temporal=17,
                    num_levels=5,
                    max_res=(128, 128, 128, 150),
                    base_res=(16, 16, 16, 15),
                    use_linear=False,
                    mode="mt",
                    mask_reso=(64, 64, 64),
                    mask_log2_hash_size=18,
                    mask_type="global",
                    st_mlp_mode="shared",
                    interp="linear",
                    mask_init_mean=-1,
                ),
            ],
        ),
    ),
    # Both parameter groups use identical settings; the comprehension still
    # constructs a fresh optimizer/scheduler config object for each group.
    optimizers={
        group: {
            "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=40_000),
        }
        for group in ("proposal_networks", "fields")
    },
)
